梦想不会自己发光,真正闪耀的是那个为梦狂奔的你。献给知行的孩子们!(Eric.He著)
分治算法(Divide and Conquer)不仅在查找问题中展现出高效优势,在数值计算领域同样是核心算法思想。数值计算的核心需求是对大规模数据进行精准、高效的运算(如矩阵乘法、信号变换等),传统算法往往面临时间复杂度高、运算效率低的问题。分治思想通过将大规模数值计算问题拆解为若干小规模子问题,并行或递归求解后合并结果,能显著降低时间复杂度。本文将聚焦两个经典实例——矩阵乘法的Strassen算法、傅里叶变换的快速傅里叶变换(FFT)算法,详细讲解分治思想在数值计算中的应用。
分治算法解决数值计算问题的通用步骤:
矩阵乘法是线性代数、机器学习、图像处理等领域的基础运算(如神经网络中的权重更新、图像的卷积运算)。传统矩阵乘法算法(三重循环)的时间复杂度为O(n³),对于大规模矩阵(如1000×1000及以上),运算效率极低。Strassen算法基于分治思想,通过优化子矩阵的合并策略,将时间复杂度降至O(n^log7)≈O(n².81),大幅提升了大规模矩阵乘法的运算效率。
问题描述:
算法解析:
将矩阵A、B、C分别拆分为4个(n/2)×(n/2)的子矩阵,拆分规则如下:
A = [[A₁₁, A₁₂], [A₂₁, A₂₂]],B = [[B₁₁, B₁₂], [B₂₁, B₂₂]],C = [[C₁₁, C₁₂], [C₂₁, C₂₂]]
Strassen提出通过构造7个中间矩阵(仅需7次(n/2)×(n/2)矩阵乘法),替代传统的8次乘法,中间矩阵定义如下:
通过7次乘法计算出M₁~M₇后,再通过8次矩阵加法得到子矩阵C₁₁~C₂₂:
将计算得到的子矩阵C₁₁、C₁₂、C₂₁、C₂₂按原拆分规则拼接,得到最终的n×n矩阵C=A×B。
算法步骤:
给定4×4矩阵A和B,计算C=A×B,步骤如下:
代码包含矩阵拆分、矩阵加减、Strassen算法核心函数,以及矩阵扩展(处理非2的幂次方规模矩阵)功能,使用vector容器存储矩阵,保证代码的灵活性和可读性。
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
// 定义矩阵类型:vector<vector<int>>(此处使用int,若需处理浮点数可改为double)
typedef vector<vector<int>> Matrix;
// 矩阵加法:两个矩阵A和B相加,返回新矩阵
Matrix matrixAdd(const Matrix& A, const Matrix& B) {
int n = A.size();
Matrix res(n, vector<int>(n, 0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
res[i][j] = A[i][j] + B[i][j];
}
}
return res;
}
// 矩阵减法:两个矩阵A和B相减,返回新矩阵
Matrix matrixSub(const Matrix& A, const Matrix& B) {
int n = A.size();
Matrix res(n, vector<int>(n, 0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
res[i][j] = A[i][j] - B[i][j];
}
}
return res;
}
// 矩阵拆分:将n×n矩阵拆分为4个(n/2)×(n/2)子矩阵
void matrixSplit(const Matrix& A, Matrix& A11, Matrix& A12, Matrix& A21, Matrix& A22) {
int n = A.size();
int mid = n / 2;
A11.resize(mid, vector<int>(mid));
A12.resize(mid, vector<int>(mid));
A21.resize(mid, vector<int>(mid));
A22.resize(mid, vector<int>(mid));
for (int i = 0; i < mid; ++i) {
for (int j = 0; j < mid; ++j) {
A11[i][j] = A[i][j];
A12[i][j] = A[i][j + mid];
A21[i][j] = A[i + mid][j];
A22[i][j] = A[i + mid][j + mid];
}
}
}
// 矩阵合并:将4个(n/2)×(n/2)子矩阵合并为n×n矩阵
Matrix matrixMerge(const Matrix& A11, const Matrix& A12, const Matrix& A21, const Matrix& A22) {
int mid = A11.size();
int n = mid * 2;
Matrix res(n, vector<int>(n, 0));
for (int i = 0; i < mid; ++i) {
for (int j = 0; j < mid; ++j) {
res[i][j] = A11[i][j];
res[i][j + mid] = A12[i][j];
res[i + mid][j] = A21[i][j];
res[i + mid][j + mid] = A22[i][j];
}
}
return res;
}
// 传统矩阵乘法:用于小规模子矩阵计算,避免递归过深
Matrix traditionalMult(const Matrix& A, const Matrix& B) {
int n = A.size();
int m = B[0].size();
int p = B.size();
Matrix res(n, vector<int>(m, 0));
for (int i = 0; i < n; ++i) {
for (int k = 0; k < p; ++k) {
if (A[i][k] == 0) continue; // 剪枝:跳过0元素,提升效率
for (int j = 0; j < m; ++j) {
res[i][j] += A[i][k] * B[k][j];
}
}
}
return res;
}
// Strassen算法核心函数
Matrix strassenMult(const Matrix& A, const Matrix& B) {
int n = A.size();
// 递归终止条件:当矩阵规模小于等于2时,使用传统乘法
if (n <= 2) {
return traditionalMult(A, B);
}
// 1. 分解矩阵
Matrix A11, A12, A21, A22;
Matrix B11, B12, B21, B22;
matrixSplit(A, A11, A12, A21, A22);
matrixSplit(B, B11, B12, B21, B22);
// 2. 计算7个中间矩阵M1-M7
Matrix M1 = strassenMult(matrixAdd(A11, A22), matrixAdd(B11, B22));
Matrix M2 = strassenMult(matrixAdd(A21, A22), B11);
Matrix M3 = strassenMult(A11, matrixSub(B12, B22));
Matrix M4 = strassenMult(A22, matrixSub(B21, B11));
Matrix M5 = strassenMult(matrixAdd(A11, A12), B22);
Matrix M6 = strassenMult(matrixSub(A21, A11), matrixAdd(B11, B12));
Matrix M7 = strassenMult(matrixSub(A12, A22), matrixAdd(B21, B22));
// 3. 计算子矩阵C11-C22
Matrix C11 = matrixAdd(matrixSub(matrixAdd(M1, M4), M5), M7);
Matrix C12 = matrixAdd(M3, M5);
Matrix C21 = matrixAdd(M2, M4);
Matrix C22 = matrixAdd(matrixSub(matrixAdd(M1, M3), M2), M6);
// 4. 合并子矩阵,返回结果
return matrixMerge(C11, C12, C21, C22);
}
// 矩阵扩展:将矩阵扩展为2的幂次方规模(补0)
Matrix matrixExpand(const Matrix& A, int targetSize) {
int n = A.size();
Matrix res(targetSize, vector<int>(targetSize, 0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
res[i][j] = A[i][j];
}
}
return res;
}
// 对外接口:处理任意规模的方阵乘法
Matrix strassenMatrixMult(const Matrix& A, const Matrix& B) {
// 检查输入矩阵是否为方阵且维度一致
if (A.size() != A[0].size() || B.size() != B[0].size() || A.size() != B.size()) {
throw invalid_argument("输入必须是维度一致的方阵");
}
int n = A.size();
// 计算最小的2的幂次方,使其大于等于n
int targetSize = 1;
while (targetSize < n) {
targetSize <<= 1; // 等价于targetSize *= 2
}
// 扩展矩阵(若需)
Matrix AExpanded = matrixExpand(A, targetSize);
Matrix BExpanded = matrixExpand(B, targetSize);
// 执行Strassen乘法
Matrix CExpanded = strassenMult(AExpanded, BExpanded);
// 截取结果矩阵的前n×n部分(去除扩展的0元素)
Matrix C(n, vector<int>(n, 0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
C[i][j] = CExpanded[i][j];
}
}
return C;
}
// 打印矩阵
void printMatrix(const Matrix& mat) {
for (const auto& row : mat) {
for (int val : row) {
cout << val << " ";
}
cout << endl;
}
}
// 测试案例
int main() {
// 4×4矩阵A和B
Matrix A = {
{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10, 11, 12},
{13, 14, 15, 16}
};
Matrix B = {
{17, 18, 19, 20},
{21, 22, 23, 24},
{25, 26, 27, 28},
{29, 30, 31, 32}
};
try {
// 使用Strassen算法计算
Matrix CStrassen = strassenMatrixMult(A, B);
// 输出结果
cout << "Strassen算法结果:" << endl;
printMatrix(CStrassen);
} catch (const exception& e) {
cout << "错误:" << e.what() << endl;
return 1;
}
return 0;
}
傅里叶变换(Fourier Transform)是数字信号处理、图像处理、通信等领域的核心技术,用于将时域离散信号转换为频域离散信号;,从而分析信号的频率成分(如声音的音调、图像的纹理)。传统离散傅里叶变换(DFT)的时间复杂度为O(n²),对于大规模信号(如高清图像、长时音频)处理效率极低。快速傅里叶变换(FFT)基于分治思想,通过利用复数单位根的对称性和周期性,将时间复杂度降至O(n log n),是傅里叶变换工程应用的基础。
核心概念:
离散傅里叶变换(DFT)的定义:
算法解析:
将长度为n的信号x[j]按索引的奇偶性拆分为两个长度为n/2的子信号:
根据DFT定义,原信号的DFT结果可拆分为偶序列和奇序列的DFT结果的组合:
其中X₀[k]为偶序列x₀[j]的DFT结果,X₁[k]为奇序列x₁[j]的DFT结果。
递归求解偶序列x₀[j]和奇序列x₁[j]的DFT结果(即X₀[k]和X₁[k])。当子信号长度为1时,DFT结果等于信号本身(递归终止条件)。
利用单位根的周期性(Wₙ^k = Wₙ/2^(k mod (n/2)))和对称性(Wₙ^(k+n/2) = -Wₙ^k),将X₀[k]和X₁[k]组合为原信号的DFT结果X[k]和X[k+n/2],完成合并。
算法步骤:
给定长度为8的离散信号x = [x0, x1, x2, x3, x4, x5, x6, x7],计算其FFT结果X,步骤如下:
代码实现按时间抽取的FFT算法,使用复数容器存储信号和变换结果,包含信号扩展、FFT核心计算、逆FFT(用于信号还原验证)及信号打印功能。
#include <iostream>
#include <vector>
#include <complex>
#include <cmath>
#include <iomanip>
using namespace std;
typedef complex<double> Complex;
const double PI = acos(-1.0);
// 信号扩展:将信号长度扩展为2的幂次方(零填充)
vector<Complex> signalExpand(const vector<Complex>& x, int targetSize) {
vector res(targetSize, 0);
for (int i = 0; i < x.size(); ++i) {
res[i] = x[i];
}
return res;
}
// FFT核心函数(按时间抽取,递归实现)
void fftRecursive(vector<Complex>& x) {
int n = x.size();
// 递归终止条件:信号长度为1时,无需变换
if (n == 1) return;
// 1. 分解:按奇偶索引拆分信号
vector<Complex> even(n / 2), odd(n / 2);
for (int i = 0; 2 * i < n; ++i) {
even[i] = x[2 * i];
odd[i] = x[2 * i + 1];
}
// 递归处理子信号
fftRecursive(even);
fftRecursive(odd);
// 2. 合并:计算FFT结果
for (int k = 0; 2 * k < n; ++k) {
// 计算n次单位根 W_n^k = e^(-2πik/n)
Complex W = exp(Complex(0, -2 * PI * k / n));
x[k] = even[k] + W * odd[k];
x[k + n / 2] = even[k] - W * odd[k];
}
}
// 逆FFT(用于信号还原,验证FFT正确性)
void ifftRecursive(vector<Complex>& X) {
int n = X.size();
if (n == 1) return;
// 逆FFT的单位根为 W_n^k = e^(2πik/n)(正号)
vector<Complex> even(n / 2), odd(n / 2);
for (int i = 0; 2 * i < n; ++i) {
even[i] = X[2 * i];
odd[i] = X[2 * i + 1];
}
ifftRecursive(even);
ifftRecursive(odd);
for (int k = 0; 2 * k < n; ++k) {
Complex W = exp(Complex(0, 2 * PI * k / n));
X[k] = (even[k] + W * odd[k]) / (double)n; // 逆变换需除以n
X[k + n / 2] = (even[k] - W * odd[k]) / (double)n;
}
}
// 对外接口:处理任意长度的信号,返回FFT结果
vector<Complex> fft(vector<Complex> x) {
int n = x.size();
// 计算最小的2的幂次方,使其大于等于n
int targetSize = 1;
while (targetSize < n) {
targetSize <<= 1;
}
// 扩展信号
x = signalExpand(x, targetSize);
// 执行FFT
fftRecursive(x);
return x;
}
// 打印信号(实部,保留4位小数)
void printSignal(const vector<Complex>& x, int nSamples) {
for (int i = 0; i < nSamples; ++i) {
cout << fixed << setprecision(4) << x[i].real() << " ";
if ((i + 1) % 10 == 0) cout << endl; // 每10个数据换行
}
cout << endl;
}
// 测试案例:生成正弦信号,验证FFT与逆FFT
int main() {
// 生成信号:2Hz + 5Hz的正弦波混合信号(采样频率100Hz,采样点数64)
int fs = 100; // 采样频率
int nSamples = 64; // 采样点数
vector<Complex> x(nSamples);
for (int i = 0; i < nSamples; ++i) {
double t = (double)i / fs; // 时间轴
// 混合正弦信号:sin(2π*2t) + sin(2π*5t)
double val = sin(2 * PI * 2 * t) + sin(2 * PI * 5 * t);
x[i] = Complex(val, 0);
}
cout << "原始信号(前64个点,实部):" << endl;
printSignal(x, nSamples);
// 1. 执行FFT
vector<Complex> X = fft(x);
int fftSize = X.size();
// 计算频率轴(仅取前半部分,因为FFT结果对称)
vector<Complex> freq(fftSize / 2);
for (int k = 0; k < fftSize / 2; ++k) {
freq[k] = (double)k * fs / fftSize;
}
// 计算幅度谱(FFT结果的模,归一化)
vector<Complex> amplitude(fftSize / 2);
for (int k = 0; k < fftSize / 2; ++k) {
amplitude[k] = abs(X[k]) / fftSize * 2;
}
// 2. 执行逆FFT,还原信号
vector<Complex> xRecovered = X; // 复制FFT结果用于逆变换
ifftRecursive(xRecovered);
// 3. 输出结果验证
cout << "\nFFT频率轴(前20个频率,单位:Hz):" << endl;
for (int k = 0; k < 20; ++k) {
cout << fixed << setprecision(2) << freq[k] << " ";
}
cout << endl;
cout << "\nFFT幅度谱(前20个频率的幅度):" << endl;
for (int k = 0; k < 20; ++k) {
cout << fixed << setprecision(4) << amplitude[k] << " ";
}
cout << endl;
// 计算原始信号与还原信号的误差
double maxError = 0.0;
cout << "\n原始信号与还原信号的误差(前10个点):" << endl;
for (int i = 0; i < 10; ++i) {
double orig = x[i].real();
double rec = xRecovered[i].real();
double error = abs(orig - rec);
if (error > maxError) maxError = error;
cout << fixed << setprecision(4) << "原始:" << orig << ", 还原:" << rec << ", 误差:" << error << endl;
}
cout << "\n最大还原误差:" << fixed << setprecision(6) << maxError << endl;
cout << "还原是否成功:" << (maxError < 1e-10 ? "是" : "否") << endl;
return 0;
}
| 对比维度 | Strassen算法(矩阵乘法) | FFT算法(傅里叶变换) |
|---|---|---|
| 分治核心 | 将大矩阵拆分为子矩阵,通过优化中间矩阵减少乘法运算次数 | 将信号按奇偶索引拆分,利用单位根的对称性和周期性减少运算次数 |
| 时间复杂度 | O(n^log7)≈O(n².81) | O(n log n) |
| 分解方式 | 按矩阵维度均匀拆分(n×n→4个(n/2)×(n/2)) | 按信号索引奇偶性拆分(n点→2个n/2点) |
| 合并关键 | 通过中间矩阵的加减组合得到最终结果,合并逻辑复杂 | 利用单位根的对称性组合子信号FFT结果,合并逻辑为“蝴蝶操作” |
| 应用领域 | 线性代数、机器学习、图像处理(矩阵运算相关) | 信号处理、通信、图像处理(频谱分析相关) |
| 核心挑战 | 中间加减运算带来的精度损失和额外开销 | 频谱泄漏、零填充选择、复数运算的解读 |
学习建议:先理解两种算法的分治拆分逻辑(矩阵拆分、信号奇偶拆分),再深入研究合并步骤的数学原理(Strassen中间矩阵组合、FFT单位根特性);通过对比传统算法与分治算法的效率差异,体会分治思想的优化价值;结合实际应用场景(如矩阵乘法的神经网络应用、FFT的音频频谱分析),加深对算法的理解与应用能力。